package rfc3961

import (
	

	
)

// prfconstant is the usage constant passed to DeriveKey when deriving the key
// used by the RFC 3961 pseudo-random function: the three-octet string "prf".
const prfconstant = "prf"

// DeriveRandom implements the RFC 3961 defined function: DR(Key, Constant) = k-truncate(E(Key, Constant, initial-cipher-state)).
//
// key: base key or protocol key. Likely to be a key from a keytab file.
//
// usage: a constant.
//
// n: block size in bits (not bytes) - note if you use something like aes.BlockSize this is in bytes.
//
// k: key length / key seed length in bits. Eg. for AES256 this value is 256.
//
// e: the encryption etype function to use.
func (,  []byte,  etype.EType) ([]byte, error) {
	 := .GetCypherBlockBitLength()
	 := .GetKeySeedBitLength()
	//Ensure the usage constant is at least the size of the cypher block size. Pass it through the nfold algorithm that will "stretch" it if needs be.
	 := Nfold(, )
	//k-truncate implemented by creating a byte array the size of k (k is in bits hence /8)
	 := make([]byte, /8)
	// Keep feeding the output back into the encryption function until it is no longer short than k.
	, ,  := .EncryptData(, )
	if  != nil {
		return , 
	}
	for  := copy(, );  < len(); {
		_, , _ = .EncryptData(, )
		 =  + copy([:], )
	}
	return , nil
}

// DeriveKey derives a key from the protocol key based on the usage and the etype's specific methods.
func (,  []byte,  etype.EType) ([]byte, error) {
	,  := .DeriveRandom(, )
	if  != nil {
		return nil, 
	}
	return .RandomToKey(), nil
}

// RandomToKey implements the RFC 3961 random-to-key operation for etypes where
// it is the identity function. The input is handed back unchanged: the very
// same slice, not a copy.
func RandomToKey(random []byte) []byte {
	return random
}

// DES3RandomToKey returns a key from the bytes provided according to the definition in RFC 3961 for DES3 etypes.
func ( []byte) []byte {
	 := fixWeakKey(stretch56Bits([:7]))
	 := fixWeakKey(stretch56Bits([7:14]))
	 = append(, ...)
	 := fixWeakKey(stretch56Bits([14:21]))
	 = append(, ...)
	return 
}

// DES3StringToKey returns a key derived from the string provided according to the definition in RFC 3961 for DES3 etypes.
func (,  string,  etype.EType) ([]byte, error) {
	 :=  + 
	 := .RandomToKey(Nfold([]byte(), .GetKeySeedBitLength()))
	return .DeriveKey(, []byte("kerberos"))
}

// PseudoRandom function as defined in RFC 3961
func (,  []byte,  etype.EType) ([]byte, error) {
	 := .GetHashFunc()()
	.Write()
	 := .Sum(nil)[:.GetMessageBlockByteSize()]
	,  := .DeriveKey(, []byte(prfconstant))
	if  != nil {
		return []byte{}, 
	}
	, ,  := .EncryptData(, )
	if  != nil {
		return []byte{}, 
	}
	return , nil
}

func stretch56Bits( []byte) []byte {
	 := make([]byte, len(), len())
	copy(, )
	var  byte
	for ,  := range  {
		,  := calcEvenParity()
		[] = 
		if  != 0 {
			 =  | (1 << uint(+1))
		} else {
			 =  &^ (1 << uint(+1))
		}
	}
	_,  = calcEvenParity()
	 = append(, )
	return 
}

// calcEvenParity fixes the parity bit (the low-order bit) of a single DES key
// byte.
//
// It counts the 1 bits among the seven high-order bits of b. If that count is
// even, the low bit is set; otherwise it is cleared. Despite the name, the
// returned byte therefore always has an odd number of 1 bits.
//
// It returns the original low-order bit of b (0 or 1) and the adjusted byte.
func calcEvenParity(b byte) (uint8, uint8) {
	lowBit := b & 1
	ones := 0
	for rest := b >> 1; rest != 0; rest >>= 1 {
		if rest&1 == 1 {
			ones++
		}
	}
	if ones%2 != 0 {
		// Odd number of 1s already: clear the parity bit.
		return lowBit, b &^ 1
	}
	// Even number of 1s: set the parity bit.
	return lowBit, b | 1
}

func fixWeakKey( []byte) []byte {
	if weak() {
		[7] ^= 0xF0
	}
	return 
}

// weak reports whether b is exactly one of the 4 DES weak keys or 12 DES
// semi-weak keys, written with their parity bits. Inputs of any length other
// than 8 bytes never match.
func weak(b []byte) bool {
	// weak keys from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-67r1.pdf
	candidates := [][]byte{
		// Weak keys.
		{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
		{0xFE, 0xFE, 0xFE, 0xFE, 0xFE, 0xFE, 0xFE, 0xFE},
		{0xE0, 0xE0, 0xE0, 0xE0, 0xF1, 0xF1, 0xF1, 0xF1},
		{0x1F, 0x1F, 0x1F, 0x1F, 0x0E, 0x0E, 0x0E, 0x0E},
		// Semi-weak keys (listed in pairs).
		{0x01, 0x1F, 0x01, 0x1F, 0x01, 0x0E, 0x01, 0x0E},
		{0x1F, 0x01, 0x1F, 0x01, 0x0E, 0x01, 0x0E, 0x01},
		{0x01, 0xE0, 0x01, 0xE0, 0x01, 0xF1, 0x01, 0xF1},
		{0xE0, 0x01, 0xE0, 0x01, 0xF1, 0x01, 0xF1, 0x01},
		{0x01, 0xFE, 0x01, 0xFE, 0x01, 0xFE, 0x01, 0xFE},
		{0xFE, 0x01, 0xFE, 0x01, 0xFE, 0x01, 0xFE, 0x01},
		{0x1F, 0xE0, 0x1F, 0xE0, 0x0E, 0xF1, 0x0E, 0xF1},
		{0xE0, 0x1F, 0xE0, 0x1F, 0xF1, 0x0E, 0xF1, 0x0E},
		{0x1F, 0xFE, 0x1F, 0xFE, 0x0E, 0xFE, 0x0E, 0xFE},
		{0xFE, 0x1F, 0xFE, 0x1F, 0xFE, 0x0E, 0xFE, 0x0E},
		{0xE0, 0xFE, 0xE0, 0xFE, 0xF1, 0xFE, 0xF1, 0xFE},
		{0xFE, 0xE0, 0xFE, 0xE0, 0xFE, 0xF1, 0xFE, 0xF1},
	}
	for _, c := range candidates {
		if bytes.Equal(b, c) {
			return true
		}
	}
	return false
}